
ATST-SED

The official implementation of "Fine-tune the pretrained ATST model for sound event detection" (accepted by ICASSP 2024).

This work is closely related to ATST and ATST-Frame. Please check these works if you want to learn about the principles behind ATST-SED.

License: MIT

Paper 🤩 | Issues 😅 | Lab 🙉 | Contact 😘

Introduction

ATST-SED introduces a semi-supervised fine-tuning strategy for making better use of a pretrained model in SED. ATST-SED vs FrameATST:

  1. FrameATST is a pretrained model that gives high-quality frame-wise audio representations. It works well on various audio downstream tasks, including AudioSet (clip-wise audio tagging) and AudioSetStrong (frame-wise SED).
  2. ATST-SED is an application-specific (in this work, DESED) fine-tuned model that utilizes the FrameATST model. It works well when: a. we only care about a few sound events; b. we only have a small amount of labelled data for these events and a relatively larger amount of unlabelled data.
  3. If you are looking for a high-quality inference model for AudioSet/AudioSetStrong, you could refer to the inference code of FrameATST.
  4. If you are looking for a high-quality inference model for DESED, or want to train your own SED model with your own data, you could refer to the inference code of ATST-SED.
[Figure: The proposed fine-tuning method for ATST-SED]

Update Notices

  • Quick inference: Added a script for quick inference on a given audio file of any length. Discussions are in this issue.

  • DESED free download for Chinese users: Downloading the DESED dataset can be frustrating, so we provide a shared link (via a Chinese cloud disk) for the DESED_dataset.

  • Validation dataset definition: A typo in the validation dataset definition has been fixed; the explanation is here.

  • Real dataset download: The 7000+ strongly-labelled audio clips extracted from AudioSet are provided in this issue.

  • Strong val dataset: The meta files of this dataset have now been added to the repo.

  • About batch sizes: If you change the batch sizes when fine-tuning ATST-Frame (stage 1/2), you will probably need to change n_epochs and n_epochs_warmup in the configuration file train/local/confs/stage2.yaml accordingly. The fine-tuning of ATST-SED is sensitive to the batch sizes, and you might not reproduce the reported results when using smaller batch sizes. An ablation study of batch size setups is shown in the model performance section below.
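
For reference, below is a minimal sketch of one way to adjust these two settings when you shrink the batch sizes. The proportional-scaling rule (keeping the total number of optimizer updates roughly constant) is our own assumption, not something specified in the paper; the reference values are taken from one row of the ablation table in the Performance section.

# Heuristic sketch (assumption, not an official rule): scale n_epochs and
# n_epochs_warmup roughly in proportion to the total batch size, so the total
# number of optimizer updates stays about the same. Reference values are one
# row of the ablation table in the Performance section below.
REF_BATCH_SIZES = [12, 12, 24, 24]
REF_N_EPOCHS = 125
REF_N_EPOCHS_WARMUP = 5

def scaled_schedule(new_batch_sizes):
    ratio = sum(new_batch_sizes) / sum(REF_BATCH_SIZES)
    return max(1, round(REF_N_EPOCHS * ratio)), max(1, round(REF_N_EPOCHS_WARMUP * ratio))

print(scaled_schedule([4, 4, 8, 8]))  # -> (42, 2), close to the 40 / 2 row of the table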

Comparison with the DCASE code

To help the SED community better understand the code and implementation details, we developed the algorithm based on the baseline code of the DCASE 2023 challenge task 4. Namely, the training pipeline is built on pytorch-lightning.

We changed some components of the baseline code; the other parts of the desed_task package are left unchanged.

Get started

  1. To reproduce our experiments, please first ensure you have the full DESED dataset (including the 3000+ strongly-labelled real audio clips from AudioSet).

  2. Ensure you have the correct environment. The environment of this code is the same as that of the DCASE 2023 baseline; please refer to their docs/code to configure your environment.

  3. Download the pretrained ATST checkpoint (atst_as2M.ckpt). Note that this checkpoint is fine-tuned on AudioSet-2M.
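
As an optional sanity check (not part of the official steps), you can verify that the downloaded checkpoint loads before editing the configs; the path below is a placeholder for wherever you saved atst_as2M.ckpt.

# Optional sanity check: make sure the downloaded checkpoint can be loaded.
# The path is a placeholder; on very recent PyTorch versions you may need to
# pass weights_only=False to torch.load for non-tensor checkpoint contents.
import torch

ckpt = torch.load("/path/to/atst_as2M.ckpt", map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])  # peek at the top-level keys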

  4. Clone the ATST-SED code by:

git clone https://github.com/Audio-WestlakeU/ATST-SED.git
  5. Install our desed_task package by:
cd ATST-SED
pip install -e .
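
A quick way to confirm the editable install worked is that the package imports from anywhere:

# After `pip install -e .`, the desed_task package should be importable.
import desed_task
print(desed_task.__file__)
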
  6. Change all required paths in train/local/confs/stage1.yaml and train/local/confs/stage2.yaml to your own paths. Note that the pretrained ATST checkpoint path should be changed in both files.
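
Below is a small convenience check (ours, not part of the repository) that loads one of the config files and flags string values that look like paths but do not exist on disk, which makes missed entries easy to spot:

# Convenience check (not part of the repo): report config values that look like
# filesystem paths but do not exist, to catch paths you forgot to update.
import os
import yaml  # PyYAML

def missing_paths(node, prefix=""):
    if isinstance(node, dict):
        for key, value in node.items():
            yield from missing_paths(value, f"{prefix}{key}.")
    elif isinstance(node, str) and ("/" in node or node.endswith((".ckpt", ".tsv", ".csv"))):
        if not os.path.exists(node):
            yield prefix.rstrip("."), node

with open("train/local/confs/stage2.yaml") as f:
    for key, value in missing_paths(yaml.safe_load(f)):
        print(f"{key}: {value} (not found)")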

  7. Start training stage 1 by:

python train_stage1.py --gpus YOUR_DEVICE_ID,

We also supply a pretrained stage 1 checkpoint, Stage_1.ckpt, which you can fine-tune from directly. If you cannot run stage 1 with accm_grad=1 (i.e., the default batch sizes do not fit on your GPU), we recommend using this checkpoint first.

  8. After finishing the stage 1 training, change the model_init path in train/local/confs/stage2.yaml to the stage 1 checkpoint path (we save the top-5 models in both training stages; you could use the best one as the model initialization for stage 2, but any of the top-5 models should give similar results).
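
If you are not sure which stage 1 checkpoint to use, here is a trivial sketch that lists the saved checkpoints so you can paste one into model_init (the output directory is a placeholder; point it at wherever your stage 1 run wrote its checkpoints):

# List candidate stage 1 checkpoints; the directory below is a placeholder.
from pathlib import Path

for ckpt in sorted(Path("/path/to/stage1_exp/checkpoints").glob("*.ckpt")):
    print(ckpt)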

  9. Start training stage 2 by:

python train_stage2.py --gpus YOUR_DEVICE_ID,

Performance

We report results on both the DESED development set and the public evaluation set. The external set is the extra data extracted from AudioSet/AudioSetStrong. Please do not confuse it with the 3000+ strongly-labelled real audio clips from AudioSet.

Please note that ATST-SED also achieves top-ranked performance on the public evaluation dataset without using the external dataset, but we did not report this in our paper due to the limited writing space. The top-1 model used extra weakly-labelled data from AudioSet; we are still mining this part of the data to improve the model performance.

Dataset                 | External set | PSDS_1 | PSDS_2 | ckpt
DCASE dev. set          | -            | 0.583  | 0.810  | Stage2_wo_ext.ckpt
DCASE public eval. set  | -            | 0.631  | 0.833  | same as above
DCASE dev. set          | Used         | 0.587  | 0.812  | Stage2_w_ext.ckpt
DCASE public eval. set  | Used         | 0.631  | 0.846  | same as above

Two fine-tuned ATST-SED checkpoints are provided. The checkpoint file trained with the external dataset is broken, but the one trained without external data performs similarly. You can download them and use them directly.

If you want to check the performance of the fine-tuned checkpoint:

python train_stage2.py --gpus YOUR_DEVICE_ID, --test_from_checkpoint YOUR_CHECKPOINT_PATH

Ablation on batch sizes:

We report the model performances on the development set with the following setups:

Batch sizes      | n_epochs | n_epochs_warmup | accm_grad | PSDS_1 | PSDS_2
[4, 4, 8, 8]     | 40       | 2               | \         | 0.535  | 0.784
[8, 8, 16, 16]   | 80       | 2               | \         | 0.562  | 0.802
[12, 12, 24, 24] | 125      | 5               | \         | 0.570  | 0.805
[4, 4, 8, 8]     | 250      | 10              | 6         | 0.579  | 0.811

As shown in the table, if you cannot afford the default batch sizes, please make sure they stay at a reasonable level. Alternatively, we recommend using the accm_grad hyperparameter in stage2.yaml to enlarge the effective batch sizes. However, using accm_grad also degrades the model performance, due to its influence on the batch-norm layers of the CNN model. Compared with the reported results, you might get a poorer result of around 56%~58% PSDS1 (using the last ckpt for validation).
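
For readers unfamiliar with gradient accumulation, the generic PyTorch sketch below (a toy example, not the repository's actual training loop) shows the mechanism: gradients from accm_grad micro-batches are summed before a single optimizer update, which emulates a larger batch for the optimizer, but each BatchNorm layer still normalizes over the small micro-batch only, which is why the performance can degrade.

# Toy gradient-accumulation sketch (not the actual ATST-SED training loop).
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv1d(1, 8, kernel_size=3, padding=1),
    nn.BatchNorm1d(8),   # normalizes over the micro-batch, not the accumulated batch
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 16, 2),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

accm_grad = 6     # corresponds to the accm_grad hyperparameter in stage2.yaml
micro_batch = 4   # small batch that fits in memory

optimizer.zero_grad()
for step in range(12):                          # dummy data in place of the real loader
    x = torch.randn(micro_batch, 1, 16)
    y = torch.randint(0, 2, (micro_batch,))
    loss = criterion(model(x), y) / accm_grad   # scale so the summed gradients form an average
    loss.backward()                             # gradients accumulate into .grad
    if (step + 1) % accm_grad == 0:
        optimizer.step()                        # one update per accm_grad micro-batches
        optimizer.zero_grad()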

Citation

If you want to cite this paper:

@INPROCEEDINGS{10446159,
  author={Shao, Nian and Li, Xian and Li, Xiaofei},
  booktitle={ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 
  title={Fine-Tune the Pretrained ATST Model for Sound Event Detection}, 
  year={2024},
  volume={},
  number={},
  pages={911-915},
  keywords={Training;Event detection;Self-supervised learning;Feature extraction;Transformers;Task analysis;Speech processing;sound event detection;self-supervised learning;ATST;fine-tuning pretrained model},
  doi={10.1109/ICASSP48485.2024.10446159}}
