This repository was archived by the owner on Feb 1, 2025. It is now read-only.

Commit

Move loading of the Postprocessor out of itself
stevenguh committed Feb 25, 2020
1 parent 1bedc9d commit 48521ec
Showing 1 changed file with 26 additions and 35 deletions.
61 changes: 26 additions & 35 deletions torchvggish/vggish.py
@@ -44,41 +44,20 @@ class Postprocessor(nn.Module):
     the same PCA (with whitening) and quantization transformations."
     """
 
-    def __init__(self, pretrained, params_url, progress=True):
+    def __init__(self):
         """Constructs a postprocessor."""
         super(Postprocessor, self).__init__()
-        if pretrained:
-            self.init_params_pth_url(params_url, progress=progress)
-        else:
-            # Create empty matrix, for user's state_dict to load
-            self.pca_matrix = torch.empty(
-                (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,),
-                dtype=torch.float,
-            )
-            self.pca_means = torch.empty(
-                (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float
-            )
-
-        self.pca_matrix = nn.Parameter(self.pca_matrix, requires_grad=False)
-        self.pca_means = nn.Parameter(self.pca_means, requires_grad=False)
-
-    def init_params_pth_url(self, pca_params_dict_url, progress=True):
-        params = hub.load_state_dict_from_url(pca_params_dict_url, progress=progress)
-        self.pca_matrix = torch.as_tensor(
-            params[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float
+        # Create empty matrix, for user's state_dict to load
+        self.pca_eigen_vectors = torch.empty(
+            (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE,),
+            dtype=torch.float,
         )
-        self.pca_means = torch.as_tensor(
-            params[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float
+        self.pca_means = torch.empty(
+            (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float
         )
 
-        assert self.pca_matrix.shape == (
-            vggish_params.EMBEDDING_SIZE,
-            vggish_params.EMBEDDING_SIZE,
-        ), "Bad PCA matrix shape: %r" % (self.pca_matrix.shape,)
-        assert self.pca_means.shape == (
-            vggish_params.EMBEDDING_SIZE,
-            1,
-        ), "Bad PCA means shape: %r" % (self.pca_means.shape,)
+        self.pca_eigen_vectors = nn.Parameter(self.pca_eigen_vectors, requires_grad=False)
+        self.pca_means = nn.Parameter(self.pca_means, requires_grad=False)
 
     def postprocess(self, embeddings_batch):
         """Applies tensor postprocessing to a batch of embeddings.
@@ -105,7 +84,7 @@ def postprocess(self, embeddings_batch):
         # - Premultiply by PCA matrix of shape [output_dims, input_dims]
         #   where both are equal to embedding_size in our case.
         # - Transpose result back to [batch_size, embedding_size].
-        pca_applied = torch.mm(self.pca_matrix, (embeddings_batch.T - self.pca_means)).T
+        pca_applied = torch.mm(self.pca_eigen_vectors, (embeddings_batch.T - self.pca_means)).T
 
         # Quantize by:
         # - clipping to [min, max] range
@@ -162,15 +141,27 @@ def _vgg():
 
 
 class VGGish(VGG):
-    def __init__(self, urls, pretrained=True, preprocess=True, postprocess=True):
+    def __init__(self, urls, pretrained=True, preprocess=True, postprocess=True, progress=True):
         super().__init__(make_layers())
         if pretrained:
-            state_dict = hub.load_state_dict_from_url(urls['vggish'], progress=True)
+            state_dict = hub.load_state_dict_from_url(urls['vggish'], progress=progress)
             super().load_state_dict(state_dict)
 
         self.preprocess = preprocess
         self.postprocess = postprocess
-        self.pproc = Postprocessor(urls['pca'])
+        if self.postprocess:
+            self.pproc = Postprocessor()
+            if pretrained:
+                state_dict = hub.load_state_dict_from_url(urls['pca'], progress=progress)
+                # TODO: Convert the state_dict to torch
+                state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor(
+                    state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float
+                )
+                state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor(
+                    state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float
+                )
+
+                self.pproc.load_state_dict(state_dict)
 
     def forward(self, x, fs=None):
         if self.preprocess:
@@ -190,4 +181,4 @@ def _preprocess(self, x, fs):
         return x
 
     def _postprocess(self, x):
-        return self.pproc.postprocess(x)
+        return self.pproc(x)
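Taken together, the change means the Postprocessor no longer downloads its own PCA parameters: it now starts with empty pca_eigen_vectors / pca_means parameters, and VGGish.__init__ fetches the values from urls['pca'], converts them to torch tensors, and pushes them in with load_state_dict. Below is a minimal sketch of that loading flow done by hand, outside of VGGish. It is illustrative only: the import paths and the identity/zero stand-in values are assumptions (the real values come from the published PCA checkpoint), and it relies on PCA_EIGEN_VECTORS_NAME / PCA_MEANS_NAME matching the parameter names, which the load_state_dict call in the diff implies.

import torch

from torchvggish import vggish_params           # assumed import path
from torchvggish.vggish import Postprocessor    # file changed by this commit

# Fresh postprocessor: pca_eigen_vectors and pca_means are empty, non-trainable parameters.
pproc = Postprocessor()

# Stand-in PCA parameters with the shapes the removed asserts used to check:
# [EMBEDDING_SIZE, EMBEDDING_SIZE] eigenvectors and [EMBEDDING_SIZE, 1] means.
state_dict = {
    vggish_params.PCA_EIGEN_VECTORS_NAME: torch.eye(vggish_params.EMBEDDING_SIZE),
    vggish_params.PCA_MEANS_NAME: torch.zeros(vggish_params.EMBEDDING_SIZE, 1),
}
pproc.load_state_dict(state_dict)

# Postprocess a fake batch of embeddings: subtract the means, apply the PCA matrix,
# then run the quantization step, as in postprocess() above.
embeddings = torch.randn(8, vggish_params.EMBEDDING_SIZE)
quantized = pproc.postprocess(embeddings)

In the commit itself this wiring lives inside VGGish.__init__ (guarded by pretrained), so callers keep constructing VGGish(urls) and never touch the Postprocessor directly.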
