# adapted from https://github.com/harritaylor/torchvggish
from typing import Tuple

import torch.nn as nn
from torch import hub

import conv  # local sibling module in this repo; provides set_parameter_requires_grad

VGGISH_WEIGHTS = (
    # "https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish-cbfe8f1c.pth"
    "https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish-918c2d05.pth"
)
PCA_PARAMS = (
    "https://users.cs.cf.ac.uk/taylorh23/pytorch/models/vggish_pca_params-4d878af3.npz"
)
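# Note (added): PCA_PARAMS points at the AudioSet embedding-postprocessing
# parameters (PCA eigenvectors and means); it is defined here for completeness
# but is not used anywhere in this file.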


class VGGishParams:
    """
    These constants should not be changed. They are inlined here for convenience.
    """

    NUM_FRAMES = 96  # Frames in input mel-spectrogram patch.
    NUM_BANDS = 64  # Frequency bands in input mel-spectrogram patch.
    EMBEDDING_SIZE = 128  # Size of embedding layer.

    # Hyperparameters used in feature and example generation.
    SAMPLE_RATE = 16000
    STFT_WINDOW_LENGTH_SECONDS = 0.025
    STFT_HOP_LENGTH_SECONDS = 0.010
    NUM_MEL_BINS = NUM_BANDS
    MEL_MIN_HZ = 125
    MEL_MAX_HZ = 7500
    LOG_OFFSET = 0.01  # Offset used for stabilized log of input mel-spectrogram.
    EXAMPLE_WINDOW_SECONDS = 0.96  # Each example contains 96 10 ms frames
    EXAMPLE_HOP_SECONDS = 0.96  # with zero overlap.

    # Parameters used for embedding postprocessing.
    PCA_EIGEN_VECTORS_NAME = "pca_eigen_vectors"
    PCA_MEANS_NAME = "pca_means"
    QUANTIZE_MIN_VAL = -2.0
    QUANTIZE_MAX_VAL = +2.0
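# Note (added): EXAMPLE_WINDOW_SECONDS / STFT_HOP_LENGTH_SECONDS = 0.96 / 0.010
# = 96 frames per example, matching NUM_FRAMES; together with NUM_BANDS = 64
# mel bands this yields the 96x64 patches that the VGGish network consumes.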
"""
VGGish
Input: 96x64 1-channel spectrogram
Output: 128 Embedding
"""
class VGGish(nn.Module):
def __init__(self, feature_extract: bool):
super(VGGish, self).__init__()
        self.features = nn.Sequential(
            # NUM_BANDS (64) and EMBEDDING_SIZE (128) coincide numerically with
            # the first two conv widths of the original VGGish; here they act
            # as channel counts, not as band or embedding dimensions.
            nn.Conv2d(1, VGGishParams.NUM_BANDS, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(VGGishParams.NUM_BANDS, VGGishParams.EMBEDDING_SIZE, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
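
        # Shape note (added): the four MaxPool2d(2, 2) stages halve each
        # spatial dim, so a 1x96x64 input leaves the conv stack as 512x6x4,
        # i.e. the 512 * 24 features flattened into the first Linear below.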
        self.embeddings = nn.Sequential(
            nn.Linear(512 * 24, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, VGGishParams.EMBEDDING_SIZE),
            nn.ReLU(inplace=True),
        )
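
        # When feature_extract is True, the helper below presumably freezes
        # these parameters (requires_grad = False) so that only a downstream
        # head trains; this assumes the local conv module follows the common
        # PyTorch fine-tuning pattern for set_parameter_requires_grad.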
        conv.set_parameter_requires_grad(self.features, feature_extract)
        conv.set_parameter_requires_grad(self.embeddings, feature_extract)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 512 * 24)
        x = self.embeddings(x)
        return x


def vggish(feature_extract: bool) -> Tuple[VGGish, int]:
    """
    VGGish is a PyTorch port of TensorFlow's VGGish architecture, used to
    create embeddings for AudioSet. It produces a 128-d embedding for each
    0.96 s slice of audio. The returned model always comes pretrained.
    """
    model = VGGish(feature_extract)
    model.load_state_dict(hub.load_state_dict_from_url(VGGISH_WEIGHTS), strict=True)
    return model, VGGishParams.EMBEDDING_SIZE
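

if __name__ == "__main__":
    # Usage sketch (added, not part of the original file): embed a random
    # batch of two 96x64 single-channel patches. Downloads the pretrained
    # weights from VGGISH_WEIGHTS on first run.
    import torch

    model, embedding_size = vggish(feature_extract=True)
    model.eval()
    patches = torch.randn(2, 1, 96, 64)  # (batch, channel, frames, bands)
    with torch.no_grad():
        embeddings = model(patches)
    print(embeddings.shape)  # expected: torch.Size([2, 128])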