diff --git a/README.md b/README.md
index ffaf889..8e2b971 100644
--- a/README.md
+++ b/README.md
@@ -1,23 +1,29 @@
 # PaSST package for HEAR 2021 NeurIPS Challenge Holistic Evaluation of Audio Representations
-
 This is an implementation of [Efficient Training of Audio Transformers with Patchout](https://arxiv.org/abs/2110.05069) for the HEAR 2021 NeurIPS Challenge Holistic Evaluation of Audio Representations
-# CUDA version
+# CUDA version
+
 This implementation is tested with CUDA version 11.1 and torch installed via:
+
 ```shell
 pip3 install torch==1.8.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
 ```
+
 but it should work on newer versions of CUDA and torch.
-# Installation
+
+# Installation
+
 Install the latest version of this repo:
+
 ```shell
-pip install -e 'git+https://github.com/kkoutini/passt_hear21@0.0.25#egg=hear21passt'
+pip install hear21passt
 ```
-The models follow the [common API](https://neuralaudio.ai/hear2021-holistic-evaluation-of-audio-representations.html#common-api) of HEAR 21
+The models follow the [common API](https://neuralaudio.ai/hear2021-holistic-evaluation-of-audio-representations.html#common-api) of HEAR 21:
+
 ```shell
 hear-validator --model hear21passt.base.pt hear21passt.base
 hear-validator --model noweights.txt hear21passt.base2levelF
@@ -25,6 +31,7 @@
 hear-validator --model noweights.txt hear21passt.base2levelmel
 ```
 There are three modules available: `hear21passt.base`, `hear21passt.base2level`, and `hear21passt.base2levelmel`:
+
 ```python
 import torch
 
@@ -42,16 +49,18 @@
 print(embed.shape)
 ```
 # Getting the Logits/Class Labels
 You can get the logits (before the sigmoid activation) for the 527 classes of AudioSet:
+
 ```python
 from hear21passt.base import load_model
 
 model = load_model(mode="logits").cuda()
 logits = model(wave_signal)
 ```
-The class labels indices can be found [here](https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/metadata/class_labels_indices.csv)
+The class label indices can be found [here](https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/metadata/class_labels_indices.csv).
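+As a quick sketch, you can map the logits to AudioSet label names, assuming that CSV keeps its `index`/`mid`/`display_name` layout (the local file name below is hypothetical, use wherever you saved it):
+
+```python
+import csv
+
+import torch
+
+# Load the AudioSet class names in index order.
+with open("class_labels_indices.csv") as f:
+    labels = [row["display_name"] for row in csv.DictReader(f)]
+
+probs = logits[0].sigmoid()  # logits from the snippet above
+top5 = torch.topk(probs, k=5)
+for i, p in zip(top5.indices.tolist(), top5.values.tolist()):
+    print(f"{labels[i]}: {p:.3f}")
+```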
 
 You can also use different pre-trained models, for example, the model trained with KD `passt_s_kd_p16_128_ap486`:
+
 ```python
 from hear21passt.base import get_basic_model
 
@@ -62,7 +71,7 @@
 logits = model(wave_signal)
 ```
 
 # Supporting longer clips
-In case of an input longer than 10 seconds, the `get_scene_embeddings` method compute the average of the embedding of a 10-second overlapping windows. 
+In case of an input longer than 10 seconds, the `get_scene_embeddings` method computes the average of the embeddings of overlapping 10-second windows.
 Depending on the application, it may be useful to use a pre-trained model that can extract embeddings from 20- or 30-second clips without averaging. These variants have time positional encodings pre-trained for 20/30-second clips:
 ```python
@@ -79,8 +88,6 @@
 logits = model(wave_signal)
 ```
 
 Each pre-trained model has a specific frequency/time positional encoding; it's necessary to select the correct input shape to be able to load the models. The important variables for loading are `input_tdim`, `fstride`, and `tstride`, which specify the spectrogram's number of time frames, the patch stride over frequency, and the patch stride over time, respectively.
-
-
 ```python
 import torch
 
@@ -105,7 +112,7 @@
 model.net = get_model_passt("passt_20sec", input_tdim=2000)
 model.net = get_model_passt("passt_30sec", input_tdim=3000)
 ```
 
-If you provide the wrong spectrograms, the model may fail silently, by generating low-quality embeddings and logits. Make sure you have the correct spectrograms' config for the selected pre-trained models. 
+If you provide the wrong spectrograms, the model may fail silently by generating low-quality embeddings and logits. Make sure you use the correct spectrogram config for the selected pre-trained model.
 Models with higher spectrogram resolution need the matching spectrogram config:
 ```python
@@ -134,5 +141,3 @@
 model.mel = AugmentMelSTFT(n_mels=128, sr=32000, win_length=800, hopsize=100, n_
 ```
-
-
diff --git a/hear21passt/__init__.py b/hear21passt/__init__.py
index 4ca51a2..ce2dbaa 100644
--- a/hear21passt/__init__.py
+++ b/hear21passt/__init__.py
@@ -1,8 +1,6 @@
-
-__version__ = "0.0.25"
+__version__ = "0.0.26"
 
 
 def embeding_size(hop=50, embeding_size=1000):
     embedings = 20 * 60 * (1000 / hop)
     return embedings * embeding_size * 4 / (1024 * 1024 * 1024)  # float32 in GB
-
diff --git a/setup.py b/setup.py
index 978fde9..2e97b47 100644
--- a/setup.py
+++ b/setup.py
@@ -18,6 +18,5 @@
     },
     packages=find_packages(exclude=("tests",)),
     python_requires=">=3.7",
-    install_requires=["timm==0.4.12",
-                      "torchaudio>=0.7.0"]
-)
\ No newline at end of file
+    install_requires=["timm>=0.4.12", "torchaudio>=0.7.0"],
+)
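
For reference, a minimal sketch of the HEAR common API flow that the README above documents (assumptions: a CUDA device is available, and the input is 32 kHz audio, the sample rate the pre-trained models expect):

```python
import torch

from hear21passt.base import get_scene_embeddings, get_timestamp_embeddings, load_model

model = load_model().cuda()
# Two random 10-second clips at 32 kHz, batched as (n_sounds, n_samples).
audio = torch.rand(2, 32000 * 10).cuda()
scene_embed = get_scene_embeddings(audio, model)  # (n_sounds, embedding_dim)
timestamp_embed, timestamps = get_timestamp_embeddings(audio, model)
print(scene_embed.shape, timestamp_embed.shape, timestamps.shape)
```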