Commit
Showing 3 changed files with 20 additions and 18 deletions.
@@ -1,30 +1,37 @@
# PaSST package for HEAR 2021 NeurIPS Challenge Holistic Evaluation of Audio Representations

This is an implementation of [Efficient Training of Audio Transformers with Patchout](https://arxiv.org/abs/2110.05069) for the HEAR 2021 NeurIPS Challenge
Holistic Evaluation of Audio Representations.

# CUDA version

This implementation is tested with CUDA version 11.1 and the following torch installation:

```shell
pip3 install torch==1.8.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
```

but should work on newer versions of CUDA and torch.

# Installation

Install the latest version of this repo:

```shell
pip install -e 'git+https://github.com/kkoutini/[email protected]#egg=hear21passt'
pip install hear21passt
```

The models follow the [common API](https://neuralaudio.ai/hear2021-holistic-evaluation-of-audio-representations.html#common-api) of HEAR 21:

```shell
hear-validator --model hear21passt.base.pt hear21passt.base
hear-validator --model noweights.txt hear21passt.base2levelF
hear-validator --model noweights.txt hear21passt.base2levelmel
```

Three modules are available, `hear21passt.base`, `hear21passt.base2level`, and `hear21passt.base2levelmel`:

```python
import torch

@@ -42,16 +49,18 @@ print(embed.shape)
# Getting the Logits/Class Labels

You can get the logits (before the sigmoid activation) for the 527 classes of AudioSet:

```python
from hear21passt.base import load_model

model = load_model(mode="logits").cuda()
logits = model(wave_signal)
```

The class label indices can be found [here](https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/metadata/class_labels_indices.csv).

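Because the logits are taken before the sigmoid activation, turning them into ranked class names takes two steps: apply the sigmoid, then look up the highest-scoring indices in the label file. The following is a minimal sketch of that idea, not part of the package. It assumes the label CSV has the columns `index`, `mid`, and `display_name`, uses a three-row stand-in instead of the real 527-row file, and uses made-up logit values; `load_labels` and `top_k_labels` are hypothetical helper names.

```python
import csv
import io
import math

# Three-row stand-in for class_labels_indices.csv.
# Assumed column layout: index, mid, display_name.
SAMPLE_CSV = """index,mid,display_name
0,/m/09x0r,"Speech"
1,/m/05zppz,"Male speech, man speaking"
2,/m/02zsn,"Female speech, woman speaking"
"""


def load_labels(csv_text):
    """Map class index -> display name from CSV text."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return {int(row["index"]): row["display_name"] for row in reader}


def top_k_labels(logits, labels, k=2):
    """Apply a sigmoid to each logit and return the k most probable labels."""
    probs = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(labels[i], probs[i]) for i in ranked[:k]]


labels = load_labels(SAMPLE_CSV)
example_logits = [2.0, -1.0, 0.0]  # made-up logits for one clip
top = top_k_labels(example_logits, labels, k=2)
for name, prob in top:
    print(f"{name}: {prob:.3f}")
```

With a real model output, one row of the logits tensor could presumably be converted with `logits[0].tolist()` before being passed to `top_k_labels`, and the full label file would be read from disk instead of the inline sample.
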
You can also use different pre-trained models, for example, the model trained with knowledge distillation (KD), `passt_s_kd_p16_128_ap486`:

```python
from hear21passt.base import get_basic_model

@@ -62,7 +71,7 @@ logits = model(wave_signal)

# Supporting longer clips

For inputs longer than 10 seconds, the `get_scene_embeddings` method computes the average of the embeddings of overlapping 10-second windows.
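
The window averaging described above can be sketched in a few lines. This is an illustration of the general technique only, not the package's code: the hop between windows, the handling of the final window, and the helper names `window_starts`, `average_window_embeddings`, and `toy_embed` are all assumptions, and a toy two-dimensional "embedding" stands in for the model.

```python
def window_starts(n_samples, win, hop):
    """Start indices of overlapping windows that cover the whole signal."""
    if n_samples <= win:
        return [0]
    starts = list(range(0, n_samples - win + 1, hop))
    if starts[-1] + win < n_samples:
        # Assumed policy: add one final window aligned to the end of the signal.
        starts.append(n_samples - win)
    return starts


def average_window_embeddings(signal, embed, win, hop):
    """Embed each window and return the element-wise mean of the embeddings."""
    starts = window_starts(len(signal), win, hop)
    embeddings = [embed(signal[s:s + win]) for s in starts]
    dim = len(embeddings[0])
    return [sum(e[d] for e in embeddings) / len(embeddings) for d in range(dim)]


def toy_embed(chunk):
    """Stand-in for a model: a 2-dimensional 'embedding' (mean, max)."""
    return [sum(chunk) / len(chunk), max(chunk)]


signal = list(range(25))  # 25 toy samples instead of real audio
starts = window_starts(len(signal), win=10, hop=5)
avg = average_window_embeddings(signal, toy_embed, win=10, hop=5)
print(starts, avg)
```

For real audio, `win` would be the number of samples in 10 seconds at the model's sample rate, and `embed` would call the model on each window.
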
Depending on the application, it may be useful to use a pre-trained model that can extract embeddings from 20 or 30 seconds without averaging. These variants have pre-trained time positional encodings for 20/30 seconds:

```python

@@ -79,8 +88,6 @@ logits = model(wave_signal)

Each pre-trained model has a specific frequency/time positional encoding, so you must select the correct input shape to load it. The important variables for loading are `input_tdim`, `fstride`, and `tstride`, which specify the number of spectrogram time frames, the patch stride over frequency, and the patch stride over time, respectively.

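To see why these three values must match the checkpoint, it can help to compute how many patches a given input shape produces, since the positional encodings are tied to that grid. The sketch below is a rough illustration under stated assumptions: 16x16 patches extracted by a strided convolution without padding, and 128 mel bins. `patch_grid` is a hypothetical helper, not a function of the package.

```python
def patch_grid(n_mels, input_tdim, fstride, tstride, patch=16):
    """Patch counts along frequency and time for an unpadded strided patch embedding.

    Assumes square patches of size `patch` (16 here is an assumption).
    """
    f_dim = (n_mels - patch) // fstride + 1
    t_dim = (input_tdim - patch) // tstride + 1
    return f_dim, t_dim


# Compare the grids for 2000 and 3000 time frames with stride 10 in both directions.
for tdim in (2000, 3000):
    f_dim, t_dim = patch_grid(n_mels=128, input_tdim=tdim, fstride=10, tstride=10)
    print(f"input_tdim={tdim}: {f_dim} x {t_dim} = {f_dim * t_dim} patches")
```

Under these assumptions, changing `input_tdim` or either stride changes the grid size, so positional encodings trained for one grid would not line up with another.
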
```python
import torch

@@ -105,7 +112,7 @@ model.net = get_model_passt("passt_20sec", input_tdim=2000)
model.net = get_model_passt("passt_30sec", input_tdim=3000)
```

If you provide the wrong spectrograms, the model may fail silently by generating low-quality embeddings and logits. Make sure you use the correct spectrogram config for the selected pre-trained model.
Models with higher spectrogram resolutions need the matching spectrogram config to be specified:

```python

@@ -134,5 +141,3 @@ model.mel = AugmentMelSTFT(n_mels=128, sr=32000, win_length=800, hopsize=100, n_

```

@@ -1,8 +1,6 @@

__version__ = "0.0.25"
__version__ = "0.0.26"


def embeding_size(hop=50, embeding_size=1000):
    embedings = 20 * 60 * (1000 / hop)
    return embedings * embeding_size * 4 / (1024 * 1024 * 1024)  # float32 in GB