From 372d845cfcaafa0d2c66eae20f68e185fd70f80b Mon Sep 17 00:00:00 2001
From: Benjamin Cretois <benjamin.cretois@nina.no>
Date: Wed, 28 Feb 2024 14:01:01 +0100
Subject: [PATCH] [ADD] Pipeline for fine-tuning on DCASE

---
 dcase_fine_tune/CONFIG.yaml     |  53 +++++
 dcase_fine_tune/FTBeats.py      | 129 +++++++++++
 dcase_fine_tune/FTDataModule.py | 158 ++++++++++++++
 dcase_fine_tune/FTevaluate.py   | 374 ++++++++++++++++++++++++++++++++
 dcase_fine_tune/_utils.py       | 143 ++++++++++++
 5 files changed, 857 insertions(+)
 create mode 100644 dcase_fine_tune/CONFIG.yaml
 create mode 100644 dcase_fine_tune/FTBeats.py
 create mode 100644 dcase_fine_tune/FTDataModule.py
 create mode 100644 dcase_fine_tune/FTevaluate.py
 create mode 100644 dcase_fine_tune/_utils.py

diff --git a/dcase_fine_tune/CONFIG.yaml b/dcase_fine_tune/CONFIG.yaml
new file mode 100644
index 0000000..7f1fdaf
--- /dev/null
+++ b/dcase_fine_tune/CONFIG.yaml
@@ -0,0 +1,53 @@
+###########################################
+###########################################
+##### CONFIG FOR DCASE CHALLENGE 2024 #####
+###########################################
+###########################################
+
+##################################
+# PARAMETERS FOR DATA PROCESSING #
+##################################
+data:
+  target_fs: 16000 # used in preprocessing
+  resample: True # used in preprocessing
+  denoise: True # used in preprocessing
+  normalize: True # used in preprocessing
+  frame_length: 25.0 # used in preprocessing
+  tensor_length: 128 # used in preprocessing
+  overlap: 0.5 # used in preprocessing
+  num_mel_bins: 128 # used in preprocessing
+  max_segment_length: 1.0 # used in preprocessing
+  status: validate # used in preprocessing; one of train, validate or evaluate
+  set_type: "Validation_Set"
+  
+
+#################################
+# PARAMETERS FOR MODEL TRAINING #
+#################################
+# Make sure these parameters match the ones used for data processing,
+# otherwise the hash of the preprocessed-data folders will differ!
+
+trainer:
+  max_epochs: 1
+  default_root_dir: /data
+  accelerator: gpu
+  gpus: 1
+  batch_size: 4
+  num_workers: 4
+
+model:
+  lr: 1.0e-05
+  model_path: "/data/models/BEATs/BEATs_iter3_plus_AS2M.pt"
+  specaugment_params: null
+  # specaugment_params:
+  #   application_ratio: 1.0
+  #   time_mask: 40  
+  #   freq_mask: 40 
+
+###################################
+# PARAMETERS FOR MODEL PREDICTION #
+###################################
+predict:
+  wav_save: False
+  overwrite: True
+  tolerance: 0
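+
+# Example usage (a sketch; assumes the repository is mounted at /app and the data volume
+# at /data, as in the @hydra.main decorator of FTevaluate.py). Any key above can be
+# overridden from the command line with Hydra's dotted syntax, e.g.:
+#   python /app/dcase_fine_tune/FTevaluate.py trainer.max_epochs=5 model.lr=1.0e-04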
diff --git a/dcase_fine_tune/FTBeats.py b/dcase_fine_tune/FTBeats.py
new file mode 100644
index 0000000..b525282
--- /dev/null
+++ b/dcase_fine_tune/FTBeats.py
@@ -0,0 +1,129 @@
+import numpy as np
+
+import torch
+from torch import nn, optim
+from torch.nn import functional as F
+from torch.optim.lr_scheduler import MultiStepLR
+from torch.optim.optimizer import Optimizer
+from torchmetrics import Accuracy
+
+import pytorch_lightning as pl
+from pytorch_lightning.utilities.rank_zero import rank_zero_info
+
+from BEATs.BEATs import BEATs, BEATsConfig
+
+class BEATsTransferLearningModel(pl.LightningModule):
+    def __init__(
+        self,
+        num_target_classes: int = 2,
+        lr: float = 1e-3,
+        lr_scheduler_gamma: float = 1e-1,
+        model_path: str = "/model/BEATs_iter3_plus_AS2M.pt",
+        ft_entire_network: bool = False,  # if True, fine-tune the BEATs backbone as well as the classifier layer
+        **kwargs,
+    ) -> None:
+        """Transfer-learning wrapper around BEATs.
+        Args:
+            num_target_classes: Number of classes predicted by the classifier layer
+            lr: Initial learning rate
+            lr_scheduler_gamma: Decay factor kept for a learning-rate scheduler (currently unused)
+            model_path: Path to the pre-trained BEATs checkpoint
+            ft_entire_network: If True, fine-tune the entire network instead of only the classifier layer
+        """
+        super().__init__()
+        self.lr = lr
+        self.lr_scheduler_gamma = lr_scheduler_gamma
+        self.num_target_classes = num_target_classes
+        self.ft_entire_network = ft_entire_network
+
+        # Initialise BEATs model
+        self.checkpoint = torch.load(model_path)
+        self.cfg = BEATsConfig(
+            {
+                **self.checkpoint["cfg"],
+                "predictor_class": self.num_target_classes,
+                "finetuned_model": False,
+            }
+        )
+
+        self._build_model()
+
+        self.train_acc = Accuracy(
+            task="multiclass", num_classes=self.num_target_classes
+        )
+        self.valid_acc = Accuracy(
+            task="multiclass", num_classes=self.num_target_classes
+        )
+        self.save_hyperparameters()
+
+    def _build_model(self):
+        # 1. Load the pre-trained network
+        self.beats = BEATs(self.cfg)
+
+        print("LOADING THE PRE-TRAINED WEIGHTS")
+        self.beats.load_state_dict(self.checkpoint["model"])
+
+        # 2. Classifier
+        self.fc = nn.Linear(self.cfg.encoder_embed_dim, self.cfg.predictor_class)
+
+    def extract_features(self, x, padding_mask=None):
+        if padding_mask is not None:
+            x, _ = self.beats.extract_features(x, padding_mask)
+        else:
+            x, _ = self.beats.extract_features(x)
+        return x
+
+    def forward(self, x, padding_mask=None):
+        """Forward pass. Return x"""
+
+        # Get the representation
+        if padding_mask is not None:
+            x, _ = self.beats.extract_features(x, padding_mask)
+        else:
+            x, _ = self.beats.extract_features(x)
+
+        # Get the logits
+        x = self.fc(x)
+
+        # Mean-pool over the time dimension
+        x = x.mean(dim=1)
+
+        return x
+
+    def loss(self, lprobs, labels):
+        self.loss_func = nn.CrossEntropyLoss()
+        return self.loss_func(lprobs, labels)
+
+    def training_step(self, batch, batch_idx):
+        # 1. Forward pass:
+        x, y_true = batch
+        y_probs = self.forward(x)
+
+        # 2. Compute loss
+        train_loss = self.loss(y_probs, y_true)
+
+        # 3. Compute accuracy:
+        self.log("train_acc", self.train_acc(y_probs, y_true), prog_bar=True)
+
+        return train_loss
+
+    def validation_step(self, batch, batch_idx):
+        # 1. Forward pass:
+        x, y_true = batch
+        y_probs = self.forward(x)
+
+        # 2. Compute loss
+        self.log("val_loss", self.loss(y_probs, y_true), prog_bar=True)
+
+        # 3. Compute accuracy:
+        self.log("val_acc", self.valid_acc(y_probs, y_true), prog_bar=True)
+
+    def configure_optimizers(self):
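+        # Either optimise the full network (BEATs backbone + classifier) or, by default,
+        # only the linear classifier head; in that case the BEATs weights are simply not
+        # passed to the optimiser.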
+        if self.ft_entire_network:
+            optimizer = optim.AdamW(
+                [{"params": self.beats.parameters()}, {"params": self.fc.parameters()}],
+                lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01
+            )  
+        else:
+            optimizer = optim.AdamW(
+                self.fc.parameters(),
+                lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01
+            )  
+
+        return optimizer
\ No newline at end of file
diff --git a/dcase_fine_tune/FTDataModule.py b/dcase_fine_tune/FTDataModule.py
new file mode 100644
index 0000000..982ad56
--- /dev/null
+++ b/dcase_fine_tune/FTDataModule.py
@@ -0,0 +1,158 @@
+from torch.utils.data import Dataset, DataLoader
+from pytorch_lightning import LightningDataModule
+from sklearn.preprocessing import LabelEncoder
+import torch
+import pandas as pd
+
+
+class AudioDatasetDCASE(Dataset):
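+    """Dataset wrapping a dataframe of pre-computed features ("feature" column) and
+    string labels ("category" column); labels are integer-encoded with a LabelEncoder
+    that is either shared via `label_dict` or fitted on the dataframe itself."""
+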
+    def __init__(
+        self,
+        data_frame,
+        label_dict=None,
+    ):
+        self.data_frame = data_frame
+        self.label_encoder = LabelEncoder()
+        if label_dict is not None:
+            self.label_encoder.fit(list(label_dict.keys()))
+            self.label_dict = label_dict
+        else:
+            self.label_encoder.fit(self.data_frame["category"])
+            self.label_dict = dict(
+                zip(
+                    self.label_encoder.classes_,
+                    self.label_encoder.transform(self.label_encoder.classes_),
+                )
+            )
+
+    def __len__(self):
+        return len(self.data_frame)
+
+    def get_labels(self):
+        labels = []
+
+        for i in range(0, len(self.data_frame)):
+            label = self.data_frame.iloc[i]["category"]
+            label = self.label_encoder.transform([label])[0]
+            labels.append(label)
+
+        return labels
+
+    def __getitem__(self, idx):
+        input_feature = torch.Tensor(self.data_frame.iloc[idx]["feature"])
+        label = self.data_frame.iloc[idx]["category"]
+
+        # Encode label as integer
+        label = self.label_encoder.transform([label])[0]
+
+        return input_feature, label
+
+    def get_label_dict(self):
+        return self.label_dict
+
+class DCASEDataModule(LightningDataModule):
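+    """LightningDataModule exposing the support set of a single DCASE file as a
+    training dataloader for fine-tuning."""
+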
+    def __init__(
+        self,
+        data_frame=pd.DataFrame(),
+        batch_size=4,
+        num_workers=4,
+        tensor_length=128,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.data_frame = data_frame
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.tensor_length = tensor_length
+        self.setup()
+
+    def setup(self, stage=None):
+        # load data
+        self.complete_dataset = AudioDatasetDCASE(
+            data_frame=self.data_frame,
+        )
+
+    def train_dataloader(self):
+        train_loader = DataLoader(self.complete_dataset, 
+                                  batch_size=self.batch_size, 
+                                  num_workers=self.num_workers, 
+                                  pin_memory=False, 
+                                  collate_fn=self.collate_fn)
+        return train_loader
+
+    def get_label_dict(self):
+        label_dic = self.complete_dataset.get_label_dict()
+        return label_dic
+    
+    def collate_fn(
+            self, input_data
+    ):
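+        """Randomly crop each feature matrix to `tensor_length` frames along its second
+        dimension and remap the labels of the batch to local indices (0..n_classes-1)."""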
+        true_class_ids = list({x[1] for x in input_data})
+        new_input = []
+        for x in input_data:
+            if x[0].shape[1] > self.tensor_length:
+                rand_start = torch.randint(
+                    0, x[0].shape[1] - self.tensor_length, (1,)
+                )
+                new_input.append(
+                    (x[0][:, rand_start : rand_start + self.tensor_length], x[1])
+                )
+            else:
+                new_input.append(x)
+
+        all_images = torch.cat([x[0].unsqueeze(0) for x in new_input])
+        all_labels = (torch.tensor([true_class_ids.index(x[1]) for x in input_data]))
+
+        return (all_images, all_labels)
+    
+    
+class predictLoader():
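+    """Plain helper (not a LightningDataModule) that wraps the query set of a single
+    DCASE file into a prediction DataLoader with batch_size=1."""
+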
+    def __init__(
+        self,
+        data_frame=pd.DataFrame(),
+        batch_size=1,
+        num_workers=4,
+        tensor_length=128
+    ):
+        self.data_frame = data_frame
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.tensor_length = tensor_length
+        self.setup()
+
+    def setup(self, stage=None):
+        # load data
+        self.complete_dataset = AudioDatasetDCASE(
+            data_frame=self.data_frame,
+        )
+
+    def pred_dataloader(self):
+        pred_loader = DataLoader(self.complete_dataset, 
+                                  batch_size=self.batch_size, 
+                                  num_workers=self.num_workers, 
+                                  pin_memory=False, 
+                                  collate_fn=self.collate_fn)
+        return pred_loader
+
+
+    def collate_fn(
+            self, input_data
+    ):
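+        # Same behaviour as DCASEDataModule.collate_fn: random crop to tensor_length
+        # and remap the labels to local batch indices.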
+        true_class_ids = list({x[1] for x in input_data})
+        new_input = []
+        for x in input_data:
+            if x[0].shape[1] > self.tensor_length:
+                rand_start = torch.randint(
+                    0, x[0].shape[1] - self.tensor_length, (1,)
+                )
+                new_input.append(
+                    (x[0][:, rand_start : rand_start + self.tensor_length], x[1])
+                )
+            else:
+                new_input.append(x)
+
+        all_images = torch.cat([x[0].unsqueeze(0) for x in new_input])
+        all_labels = (torch.tensor([true_class_ids.index(x[1]) for x in input_data]))
+
+        return (all_images, all_labels)
+
diff --git a/dcase_fine_tune/FTevaluate.py b/dcase_fine_tune/FTevaluate.py
new file mode 100644
index 0000000..ef30b41
--- /dev/null
+++ b/dcase_fine_tune/FTevaluate.py
@@ -0,0 +1,374 @@
+import os
+import csv
+import hashlib
+import json
+from tqdm import tqdm
+import glob
+from datetime import datetime
+import shutil
+from copy import deepcopy
+
+import pytorch_lightning as pl
+import pandas as pd
+import numpy as np
+
+import torch
+from torch.utils.data import DataLoader
+
+from dcase_fine_tune.FTBeats import BEATsTransferLearningModel
+from dcase_fine_tune.FTDataModule import AudioDatasetDCASE, DCASEDataModule, predictLoader
+from dcase_fine_tune._utils import write_wav, write_results, merge_preds, to_dataframe, construct_path, compute_scores
+
+import hydra
+from omegaconf import DictConfig, OmegaConf
+
+def finetune_model(
+    model_path,
+    datamodule_class,
+    max_epochs,
+    num_sanity_val_steps=0,
+):
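+    """Fine-tune a BEATsTransferLearningModel on the support set provided by
+    `datamodule_class` and return the trained model."""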
+    # create the lightning trainer object
+    trainer = pl.Trainer(
+        max_epochs=max_epochs,
+        enable_model_summary=False,
+        num_sanity_val_steps=num_sanity_val_steps,
+        deterministic=True,
+        gpus=1,
+        auto_select_gpus=True,
+        callbacks=[
+            pl.callbacks.LearningRateMonitor(logging_interval="step"),
+            pl.callbacks.EarlyStopping(monitor="train_acc", mode="max", patience=max_epochs),
+        ],
+        default_root_dir="logs/",
+        enable_checkpointing=False
+    )
+
+    # create the model object
+    model = BEATsTransferLearningModel(model_path=model_path)
+
+    # train the model
+    trainer.fit(model, datamodule=datamodule_class)
+
+    return model
+
+def predict_label(cfg, model, loader, frame_shift):
+    """Run the fine-tuned model on the query loader and return the predicted labels,
+    the ground-truth labels and the begin / end time (in seconds) of each segment."""
+
+    model = model.to("cuda")
+    
+    # Collect the predicted labels and the begin / end time of each query segment
+    pred_labels = []
+    labels = []
+    begins = []
+    ends = []
+
+    for i, data in enumerate(tqdm(loader)):
+        # Get the embeddings for the query
+        feature, label = data
+        feature = feature.to("cuda")
+
+        # Calculate beginTime and endTime for each segment
+        # frame_shift is in milliseconds, so we divide by 1000 to get seconds
+        if i == 0:
+            begin = 0.0
+            end = cfg["data"]["tensor_length"] * frame_shift / 1000
+        else:
+            begin = i * cfg["data"]["tensor_length"] * frame_shift * cfg["data"]["overlap"] / 1000
+            end = begin + cfg["data"]["tensor_length"] * frame_shift / 1000
+
+        # Get the scores and the predicted class (the loader uses batch_size=1,
+        # so the prediction and the label can be reduced to scalars)
+        classification_scores = model.forward(feature)
+        predicted_labels = torch.argmax(classification_scores, dim=1)
+
+        # To scalar values
+        predicted_labels = predicted_labels.detach().to("cpu").numpy().item()
+        label = label.detach().to("cpu").numpy().item()
+
+        # Return the labels, begin and end of the detection
+        pred_labels.append(predicted_labels)
+        labels.append(label)
+        begins.append(begin)
+        ends.append(end)
+
+    # Return the predictions for all query segments
+    return pred_labels, labels, begins, ends
+
+def train_predict(
+    cfg,
+    meta_df,
+    support_spectrograms,
+    support_labels,
+    query_spectrograms,
+    query_labels,
+    target_path="/data"
+):
+    
+    # Get the filename and the frame_shift for the particular file
+    filename = os.path.basename(support_spectrograms).split("data_")[1].split(".")[0]
+    frame_shift = meta_df.loc[filename, "frame_shift"]
+
+    print("[INFO] PROCESSING {}".format(filename))
+
+    # check labels and spectrograms are all from the same file
+    assert filename in support_labels
+    assert filename in query_spectrograms
+    assert filename in query_labels
+
+    df_support = to_dataframe(support_spectrograms, support_labels)
+    supportLoader = DCASEDataModule(data_frame=df_support, 
+                                    batch_size=cfg["trainer"]["batch_size"], 
+                                    num_workers=cfg["trainer"]["num_workers"],
+                                    tensor_length=cfg["data"]["tensor_length"])
+
+    label_dic = supportLoader.get_label_dict()
+
+    #########################
+    # FINE TUNING THE MODEL #
+    #########################
+
+    # Train the model with the support data
+    print("[INFO] TRAINING THE MODEL FOR {}".format(filename))
+    model = finetune_model(model_path=cfg["model"]["model_path"], 
+                           datamodule_class=supportLoader, 
+                           max_epochs=cfg["trainer"]["max_epochs"]
+    )
+
+    #################################
+    # PREDICTING USING THE FT MODEL #
+    #################################
+
+    ### Get the query dataset ###
+    df_query = to_dataframe(query_spectrograms, query_labels)
+    queryLoader = predictLoader(data_frame=df_query,
+                                batch_size=1,
+                                num_workers=cfg["trainer"]["num_workers"],
+                                tensor_length=cfg["data"]["tensor_length"]).pred_dataloader()
+    
+    predicted_labels, labels, begins, ends = predict_label(cfg=cfg, model=model, loader=queryLoader, frame_shift=frame_shift)
+
+    ######################
+    # COMPUTE THE SCORES #
+    ######################
+
+    # Compute the scores for the analysed file -- just as information
+    acc, recall, precision, f1score = compute_scores(
+        predicted_labels=predicted_labels,  #updated_labels,
+        gt_labels=labels,
+    )
+    with open(
+        os.path.join(target_path, "summary.csv"),
+        "a",
+        newline="",
+        encoding="utf-8",
+    ) as my_file:
+        wr = csv.writer(my_file, delimiter=",")
+        wr.writerow([filename, acc, recall, precision, f1score])
+
+    # Get the results in a dataframe
+    df_result = write_results(predicted_labels, begins, ends)
+
+    # Convert the binary PredLabels (0,1) back into the POS / NEG strings, since the label encoder can differ from file to file
+    # invert the key-value pairs of the dictionary using a dictionary comprehension
+    label_dict_inv = {v: k for k, v in label_dic.items()}
+
+    # use the map method to replace the values in the "PredLabels" column
+    df_result["PredLabels"] = df_result["PredLabels"].map(label_dict_inv)
+    df_result_raw = df_result.copy()
+    df_result_raw["gt_labels"] = labels
+    df_result_raw["filename"] = filename
+    # Filter only the POS results
+    result_POS = df_result[df_result["PredLabels"] == "POS"].drop(
+        ["PredLabels"], axis=1
+    )
+
+    result_POS_merged = merge_preds(
+        df=result_POS,
+        tolerance=cfg["predict"]["tolerance"],
+        tensor_length=cfg["data"]["tensor_length"],
+    )
+
+    # Add the filename
+    result_POS_merged["Audiofilename"] = filename + ".wav"
+
+    # Place filename as first column
+    f = result_POS_merged.pop("Audiofilename")
+    result_POS_merged.insert(0, "Audiofilename", f)
+
+    # Return the dataset
+    print("[INFO] {} PROCESSED".format(filename))
+    return (
+        result_POS_merged,
+        predicted_labels,
+        labels,
+        df_result_raw,
+    )
+
+@hydra.main(version_base=None, config_path="/app/dcase_fine_tune", config_name="CONFIG.yaml")
+def main(cfg: DictConfig):
+
+    # Get training config
+    version_path = os.path.dirname(os.path.dirname(cfg["model"]["model_path"]))
+    version_name = os.path.basename(version_path)
+
+    # Build my_hash_dict from the preprocessing parameters; its hash identifies the
+    # folder created at preprocessing time, so the values must match exactly
+    keys = ["resample", "denoise", "normalize", "frame_length", "tensor_length",
+            "set_type", "overlap", "num_mel_bins", "max_segment_length"]
+    my_hash_dict = {k: cfg["data"][k] for k in keys}
+
+    # Conditionally add 'target_fs' if 'resample' is True
+    if cfg["data"]["resample"]:
+        my_hash_dict["target_fs"] = cfg["data"]["target_fs"]
+
+    # Generate hash directory name
+    hash_dir_name = hashlib.sha1(json.dumps(my_hash_dict, sort_keys=True).encode()).hexdigest()
+
+    # Base directory for data
+    base_data_path = "/data/DCASEfewshot"
+
+    # get meta, support and query paths
+    support_data_path = os.path.join(
+        "/data/DCASEfewshot",
+        cfg["data"]["status"],
+        hash_dir_name,
+        "audio",
+        "support_data_*.npz",
+    )
+    support_labels_path = os.path.join(
+        "/data/DCASEfewshot",
+        cfg["data"]["status"],
+        hash_dir_name,
+        "audio",
+        "support_labels_*.npy",
+    )
+    query_data_path = os.path.join(
+        "/data/DCASEfewshot",
+        cfg["data"]["status"],
+        hash_dir_name,
+        "audio",
+        "query_data_*.npz",
+    )
+    query_labels_path = os.path.join(
+        "/data/DCASEfewshot",
+        cfg["data"]["status"],
+        hash_dir_name,
+        "audio",
+        "query_labels_*.npy",
+    )
+    meta_df_path = os.path.join(
+        "/data/DCASEfewshot", cfg["data"]["status"], hash_dir_name, "audio", "meta.csv"
+    )
+
+    # set target path
+    target_path = os.path.join(
+        "/data/DCASEfewshot",
+        cfg["data"]["status"],
+        hash_dir_name,
+        "results",
+        "fine_tuned",
+        version_name,
+        "results_{date:%Y%m%d_%H%M%S}".format(date=datetime.now()),
+    )
+    if cfg["predict"]["overwrite"]:
+        if os.path.exists(target_path):
+            shutil.rmtree(target_path)
+
+    if not os.path.exists(target_path):
+        os.makedirs(target_path)
+
+    # save params for eval
+    param = deepcopy(cfg)
+    # Convert the DictConfig object to a standard Python dictionary
+    param = OmegaConf.to_container(param, resolve=True)
+    
+    with open(os.path.join(target_path, "param.json"), "w") as fp:
+        json.dump(param, fp)
+
+    # Get all the wav files from the Validation / Evaluation set (only needed when wav_save is enabled)
+    if cfg["predict"]["wav_save"]:
+        path = os.path.join("/data/DCASE/Development_Set", my_hash_dict["set_type"])
+        files = glob.glob(path + "/**/*.wav", recursive=True)
+
+    # List all the SUPPORT files (both labels and spectrograms)
+    support_all_spectrograms = glob.glob(support_data_path, recursive=True)
+    support_all_labels = glob.glob(support_labels_path, recursive=True)
+
+    # List all the QUERY files
+    query_all_spectrograms = glob.glob(query_data_path, recursive=True)
+    query_all_labels = glob.glob(query_labels_path, recursive=True)
+
+    # ensure lists are ordered the same
+    support_all_spectrograms.sort()
+    support_all_labels.sort()
+    query_all_spectrograms.sort()
+    query_all_labels.sort()
+
+    # Open the meta.csv containing the frame_shift for each file
+    meta_df = pd.read_csv(meta_df_path, names=["frame_shift", "filename"])
+    meta_df.drop_duplicates(inplace=True, keep="first")
+    meta_df.set_index("filename", inplace=True)
+
+    # Dataframes to store all the results
+    results = pd.DataFrame()
+    results_raw = pd.DataFrame()
+
+    # Run the main script
+    for support_spectrograms, support_labels, query_spectrograms, query_labels in zip(
+        support_all_spectrograms,
+        support_all_labels,
+        query_all_spectrograms,
+        query_all_labels,
+    ):
+        filename = (
+            os.path.basename(support_spectrograms).split("data_")[1].split(".")[0]
+        )
+        (
+            result,
+            pred_labels,
+            gt_labels,
+            result_raw,
+        ) = train_predict(
+            param,
+            meta_df,
+            support_spectrograms,
+            support_labels,
+            query_spectrograms,
+            query_labels,
+        )
+
+        results = pd.concat([results, result])
+        results_raw = pd.concat([results_raw, result_raw])
+
+        # Write the wav file if specified
+        if cfg["predict"]["wav_save"]:
+            write_wav(
+                files,
+                param,
+                gt_labels,
+                pred_labels,
+                target_fs=cfg["data"]["target_fs"],
+                target_path=target_path,
+                frame_shift=meta_df.loc[filename, "frame_shift"],
+                support_spectrograms=support_spectrograms
+            )
+
+    # Write the final results to disk
+
+    results.to_csv(
+        os.path.join(
+            target_path,
+            "eval_out.csv",
+        ),
+        index=False,
+    )
+    results_raw.to_csv(
+        os.path.join(
+            target_path,
+            "raw_eval_out.csv",
+        ),
+        index=False,
+    )
+    print("Evaluation Finished. Results saved to " + target_path)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/dcase_fine_tune/_utils.py b/dcase_fine_tune/_utils.py
new file mode 100644
index 0000000..be92c8a
--- /dev/null
+++ b/dcase_fine_tune/_utils.py
@@ -0,0 +1,143 @@
+import os
+import librosa
+import numpy as np
+import pandas as pd
+
+from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score
+
+def write_results(predicted_labels, begins, ends):
+    df_out = pd.DataFrame(
+        {
+            "Starttime": begins,
+            "Endtime": ends,
+            "PredLabels": predicted_labels,
+        }
+    )
+
+    return df_out
+
+def write_wav(
+    files,
+    cfg,
+    gt_labels,
+    pred_labels,
+    distances_to_pos=None,
+    z_scores_pos=None,
+    target_fs=16000,
+    target_path=None,
+    frame_shift=1,
+    support_spectrograms=None
+):
+    """Write a multi-channel wav containing the waveform, the ground-truth and the
+    predicted labels (plus the distance / z-score tracks when they are provided;
+    the fine-tuning pipeline does not use them)."""
+    from scipy.io import wavfile
+
+    # Some path management
+    filename = (
+        os.path.basename(support_spectrograms).split("data_")[1].split(".")[0] + ".wav"
+    )
+    # Return the final product
+    output = os.path.join(target_path, filename)
+
+    # Find the filepath for the file being analysed
+    for f in files:
+        if os.path.basename(f) == filename:
+            print(os.path.basename(f))
+            print(filename)
+            arr, _ = librosa.load(f, sr=target_fs, mono=True)
+            break
+
+    # Expand the label tracks to audio-sample resolution
+    gt_labels = np.repeat(
+        np.asarray(gt_labels).flatten(),
+        int(
+            cfg["data"]["tensor_length"]
+            * cfg["data"]["overlap"]
+            * target_fs
+            * frame_shift
+            / 1000
+        ),
+    )
+    pred_labels = np.repeat(
+        np.asarray(pred_labels).flatten(),
+        int(
+            cfg["data"]["tensor_length"]
+            * cfg["data"]["overlap"]
+            * target_fs
+            * frame_shift
+            / 1000
+        ),
+    )
+    if distances_to_pos is not None:
+        distances_to_pos = np.repeat(
+            distances_to_pos.T,
+            int(
+                cfg["data"]["tensor_length"]
+                * cfg["data"]["overlap"]
+                * target_fs
+                * frame_shift
+                / 1000
+            ),
+        )
+    if z_scores_pos is not None:
+        z_scores_pos = np.repeat(
+            z_scores_pos.T,
+            int(
+                cfg["data"]["tensor_length"]
+                * cfg["data"]["overlap"]
+                * target_fs
+                * frame_shift
+                / 1000
+            ),
+        )
+
+    # pad the label tracks with zeros up to the length of the waveform
+    gt_labels = np.pad(
+        gt_labels, (0, max(0, len(arr) - len(gt_labels))), "constant", constant_values=(0,)
+    )
+    pred_labels = np.pad(
+        pred_labels, (0, max(0, len(arr) - len(pred_labels))), "constant", constant_values=(0,)
+    )
+    if distances_to_pos is not None:
+        distances_to_pos = np.pad(
+            distances_to_pos,
+            (0, max(0, len(arr) - len(distances_to_pos))),
+            "constant",
+            constant_values=(0,),
+        )
+    if z_scores_pos is not None:
+        z_scores_pos = np.pad(
+            z_scores_pos,
+            (0, max(0, len(arr) - len(z_scores_pos))),
+            "constant",
+            constant_values=(0,),
+        )
+
+    # Stack the available channels (distances / z-scores only when provided) and write the wav
+    channels = [arr, gt_labels, pred_labels]
+    if distances_to_pos is not None:
+        channels.append(distances_to_pos / 10)
+    if z_scores_pos is not None:
+        channels.append(z_scores_pos)
+    result_wav = np.vstack(channels)
+    wavfile.write(output, target_fs, result_wav.T)
+
+def merge_preds(df, tolerance, tensor_length):
+    """Merge overlapping (or nearly overlapping, within tolerance * tensor_length)
+    positive detections into single events.
+
+    A new group starts whenever a row's Starttime exceeds the running maximum Endtime
+    of the rows before it. For example, with tolerance=0 the rows (0.0-1.0), (0.8-1.6)
+    and (3.0-3.5) are merged into the events (0.0-1.6) and (3.0-3.5).
+    """
+    df["group"] = (
+        df["Starttime"] > (df["Endtime"] + tolerance * tensor_length).shift().cummax()
+    ).cumsum()
+    result = df.groupby("group").agg({"Starttime": "min", "Endtime": "max"})
+    return result
+
+def to_dataframe(features, labels):
+    """Load the saved array and map the features and labels into a single dataframe"""
+    input_features = np.load(features)
+    labels = np.load(labels)
+    list_input_features = [input_features[key] for key in input_features.files]
+    df = pd.DataFrame({"feature": list_input_features, "category": labels})
+    return df
+
+def compute_scores(predicted_labels, gt_labels):
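+    # Recall / precision / F1 use sklearn's default binary averaging, which assumes the
+    # positive (POS) class is encoded as 1 by the label encoder.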
+    acc = accuracy_score(gt_labels, predicted_labels)
+    recall = recall_score(gt_labels, predicted_labels)
+    f1score = f1_score(gt_labels, predicted_labels)
+    precision = precision_score(gt_labels, predicted_labels)
+    print(f"Accurracy: {acc}")
+    print(f"Recall: {recall}")
+    print(f"precision: {precision}")
+    print(f"F1 score: {f1score}")
+    return acc, recall, precision, f1score
+
+def construct_path(base_dir, status, hash_dir_name, file_type, file_pattern):
+    return os.path.join(base_dir, status, hash_dir_name, "audio", f"{file_type}.{file_pattern}")
\ No newline at end of file